iT邦幫忙

2024 iThome 鐵人賽

DAY 6
0
自我挑戰組

菜鳥AI工程師給碩班學弟妹的挑戰系列 第 6

[Day6] 細講pytorch Dataset - 1

  • 分享至 

  • xImage
  •  

今天我們終於要來細談pytorch當中的Dataset,因為本人是在語音AI公司上班,所以對於影像方面並沒有研究,以下例子主要為文字與聲音。

強烈建議跟著一行行寫,不要直接複製貼上,寫完你就會懂了

1.我們先來看底下的一段code,最一開始主要有三項東東:

  1. init: 主要會把一些初始化的東西寫在這裡,比如說
    文字方面: 初始化 tokenizer
    聲音方面: 初始化 stft, fbank, mel spectrogram
    影像方面: 初始化 transform
    額外參數: 帶了txt_path進來

  2. len: 基本上就是回傳你資料的size,會跟之後batch size相除,得到你需要跑幾個step

  3. getitem: 最基本會回傳data, label

    data
    文字方面: 已經tokenizer完的tensor array
    聲音方面: torchaduio完的tensor array, 經過stft, bank…計算完的tensor array
    影像方面: 經過transform resize完的array

    label
    文字方面: 看應用的場景,分類就回傳0, 1…,摘要回傳tokenizer完的tensor array
    聲音方面: ASR回傳文字的tokenizer,enhance回傳clean的audio,分類就回傳0, 1
    影像方面: 分類回傳0, 1,切割回傳mask資訊

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, txt_path):
        self.data = []
        self.get_data(txt_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pass

    def get_data(self, txt_path):
        with open(txt_path, 'r') as f_i:
            lines = f_i.readlines()
            self.data = [line.strip() for line in lines]

那當中的get_data主要就是把你要的東西讀進來,通常格式會採用txt或json

以聲音舉例:
前面是音檔位置中間用'|'隔開,或者是用其他特殊符號隔開,後面接label

/path/to/audio|label
/path/to/audio2|label

{"audio_path": "/path/to/audio", "label": "label"}
{"audio_path": "/path/to/audio2", "label": "label"}

2.再來我們從get_item讀資料

以我最近在做的語言偵測作範例,我們透過self.data[idx]將那行資料讀進來,我一開始自學很疑惑哪裡來的idx,後來才搞懂這idx是給Dataloader呼叫時使用,假設batch size為4,且不shuffle,那麼一個batch,他就會idx = 0, idx = 1, idx = 2, idx = 3,來獲取每筆資料。
寫到這時,我們先透過in enumerate一個個遞迴來判斷這部分有沒有問題。

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, txt_path):
        self.data = []
        self.get_data(txt_path)
        self.label_mapping = {
            'ZH': 0,
            'EN': 1,
            'TW': 2,
            'HAK': 3
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        path, label = data.split('|')
        label = self.label_mapping[label]

        return path, label
    def get_data(self, txt_path):
        with open(txt_path, 'r') as f_i:
            lines = f_i.readlines()
            self.data = [line.strip() for line in lines]

if __name__ == "__main__":
    unit_test = CustomDataset('unit_test.txt')

    for idx, (path, label) in enumerate(unit_test):
        print(f'path: {path}, label: {label}')

unit_text.txt為

/path/to/audio|EN
/path/to/audio2|TW

輸出結果為
https://ithelp.ithome.com.tw/upload/images/20240810/20168446DlBCaMw0dW.png

今天就先到這裡囉~ 明天接續講怎麼處理音檔。


上一篇
[Day5] 爬蟲收集音檔(+ yt-dlp) - 3
下一篇
[Day7] 細講pytorch Dataset - 2
系列文
菜鳥AI工程師給碩班學弟妹的挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言